{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Practice: Apply the chain rule"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Define a custom expression node `Derivative(expr, v)` that symbolically represents the derivative of an expression `expr` with respect to a variable `v`.\n",
    "1. Now suppose that taking a derivative with respect to a given coordinate `x` actually requires working in a *reference coordinate system* with coordinates `r` and `s`, so your code needs to apply the chain rule identity\n",
    "\n",
    "$$ \\frac{d\\text{expr}}{dx} = \\frac{d\\text{expr}}{dr}\\frac{dr}{dx} + \\frac{d\\text{expr}}{ds}\\frac{ds}{dx}$$\n",
    "\n",
    "Write a `ChainRuleMapper` that applies this identity."
   ]
  },
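  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what the identity means for one concrete expression, take $27x^2 + e^x$ (chosen here purely for illustration). The derivative with respect to `x` is replaced by a sum over the reference coordinates, while the expression being differentiated is left untouched:\n",
    "\n",
    "$$ \\frac{d\\left(27x^2 + e^x\\right)}{dx} = \\frac{d\\left(27x^2 + e^x\\right)}{dr}\\frac{dr}{dx} + \\frac{d\\left(27x^2 + e^x\\right)}{ds}\\frac{ds}{dx}$$\n",
    "\n",
    "Note that no $d/dx$ of the original expression remains on the right-hand side; only the factors $dr/dx$ and $ds/dx$ still involve `x`."
   ]
  },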
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from pymbolic import var\n",
    "from pymbolic.primitives import Expression\n",
    "from pymbolic.mapper import IdentityMapper\n",
    "\n",
    "x = var(\"x\")\n",
    "r = var(\"r\")\n",
    "s = var(\"s\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class Derivative(Expression):\n",
    "    # ...\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To avoid a conflict with the `Derivative` node type that is already part of pymbolic (which dispatches to `map_derivative`), our node dispatches to a mapper method named `map_deriv`."
   ]
  },
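  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#clear\n",
    "# Solution\n",
    "\n",
    "class Derivative(Expression):\n",
    "    \"\"\"Symbolic derivative of *expr* with respect to the variable *v*.\"\"\"\n",
    "\n",
    "    def __init__(self, expr, v):\n",
    "        self.expr = expr\n",
    "        self.v = v\n",
    "\n",
    "    # pymbolic uses the init args for comparison, hashing, and repr().\n",
    "    def __getinitargs__(self):\n",
    "        return (self.expr, self.v)\n",
    "\n",
    "    # Name of the mapper method that mappers dispatch to for this node.\n",
    "    mapper_method = \"map_deriv\""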
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Call(Variable('sqrt'), (Derivative(Sum((Product((27, Power(Variable('x'), 2))), Call(Variable('exp'), (Variable('x'),)))), Variable('x')),))\n"
     ]
    }
   ],
   "source": [
    "expr = var(\"sqrt\")(Derivative(27*x**2+var(\"exp\")(x), x))\n",
    "print(repr(expr))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class ChainRuleMapper(IdentityMapper):\n",
    "    # ...\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#clear\n",
    "# Solution\n",
    "\n",
    "class ChainRuleMapper(IdentityMapper):\n",
    "    def map_deriv(self, expr):\n",
    "        # Differentiate the operand (expr.expr), not the Derivative node itself.\n",
    "        return sum(Derivative(self.rec(expr.expr), ref_sym)*Derivative(ref_sym, x)\n",
    "                   for ref_sym in [r, s])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's test this mapper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Call(Variable('sqrt'), (Sum((Product((Derivative(Sum((Product((27, Power(Variable('x'), 2))), Call(Variable('exp'), (Variable('x'),)))), Variable('r')), Derivative(Variable('r'), Variable('x')))), Product((Derivative(Sum((Product((27, Power(Variable('x'), 2))), Call(Variable('exp'), (Variable('x'),)))), Variable('s')), Derivative(Variable('s'), Variable('x')))))),))"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "crm = ChainRuleMapper()\n",
    "crm(expr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you are wondering why the result above is shown in the 'clumsy', parenthesis-heavy `repr` form: we haven't yet told pymbolic how to write out `Derivative` nodes in the shorter form. Here's how that can be done:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sqrt(d((27*x**2 + exp(x)))/dr*d(r)/dx + d((27*x**2 + exp(x)))/ds*d(s)/dx)\n"
     ]
    }
   ],
   "source": [
    "from pymbolic.mapper.stringifier import StringifyMapper, PREC_PRODUCT\n",
    "\n",
    "class MyStringifyMapper(StringifyMapper):\n",
    "    def map_deriv(self, expr, enclosing_prec):\n",
    "        return \"d(%s)/d%s\" % (\n",
    "            self.rec(expr.expr, PREC_PRODUCT),\n",
    "            self.rec(expr.v, PREC_PRODUCT))\n",
    "\n",
    "def stringifier(self):\n",
    "    return MyStringifyMapper\n",
    "\n",
    "Derivative.stringifier = stringifier\n",
    "print(crm(expr))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.0+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
